#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Experiment Script
# Copyright (C) 2016 by Thomas Dreibholz
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Contact: dreibh@simula.no

import sys
import os
import subprocess
import time
import datetime
import hashlib

# NorNet
from NorNetSiteSetup         import *;
from NorNetTools             import *;
from NorNetAPI               import *;
from NorNetExperimentToolbox import *;


# ------ Measurement configuration -------------------------
# Name of the measurement; also used as name of the results directory.
MeasurementName = 'hybla-buffer-v6'
# Name of the slice that is looked up with fetchNorNetSlice().
SliceName       = 'hu_zhoufeng'
# Handed to NetPerfMeter as -runtime and to ping as packet count (-c).
Runtime         = 60
# Base value handed to makePort() and startServer().
PortBase        = 12002
# Private key handed to the SSH/rsync helpers.   !!! Set your key here! !!!
SSHPrivateKey   = '/home/zhoufeng/.ssh/id_rsa'
# Node on which the measurement command line (ping + netperfmeter) is run.
LocalNodeName   = 'fisketorget.uib.nornet'
# Node whose address is used as the remote endpoint.
RemoteNodeName  = 'watson.ku.nornet'

# ###### Get variable names #################################################
def getVariableNames():
   """Return the space-separated column names written to the results summary.

   The names are listed in the order in which getVariableSettings() emits
   the corresponding values, including the two trailing file name columns.
   """
   columns = [
      'TimeStamp',
      'FromNodeIndex', 'FromNode', 'ToNodeIndex', 'ToNode',
      'FromSite', 'FromSiteIndex', 'ToSite', 'ToSiteIndex',
      'FromProvider', 'FromProviderIndex', 'ToProvider', 'ToProviderIndex',
      'SndBufSize', 'RcvBufSize', 'CC', 'CMT', 'PathMgr', 'NDiffPorts', 'IPVersion',
      'ScalarFileNameBase', 'VectorFileNameBase',
   ]
   return ' '.join(columns)


# ###### Get variable settings ##############################################
def getVariableSettings(now, localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                        ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts,
                        scalarFileNameBase = None, vectorFileNameBase = None, hashResult = False):
   """Build the value string describing one run, or its SHA-1 hash.

   The values are emitted space-separated; textual values (and, as in the
   existing result files, the provider indices) are wrapped in double quotes.

   Parameters:
      now                          -- time stamp of the run (converted with str()).
      localNode, remoteNode        -- mappings providing 'node_index' and 'node_name'.
      localSite, remoteSite        -- mappings providing 'site_utf8' and 'site_index'.
      localProvider, remoteProvider -- mappings providing 'provider_short_name'
                                      and 'provider_index'.
      ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts
                                   -- run parameters (converted with str()).
      scalarFileNameBase, vectorFileNameBase
                                   -- optional file names; only their base name
                                      is appended (quoted). Ignored when hashing.
      hashResult                   -- if true, return the hex SHA-1 digest of the
                                      UTF-8 encoded value string instead.

   Returns:
      The value string, or a 40-character hexadecimal digest.
   """
   def quoted(text):
      return '"' + text + '"'

   fields = [
      str(now),
      str(localNode['node_index']),
      quoted(localNode['node_name']),
      str(remoteNode['node_index']),
      quoted(remoteNode['node_name']),
      quoted(localSite['site_utf8']),
      str(localSite['site_index']),
      quoted(remoteSite['site_utf8']),
      str(remoteSite['site_index']),
      quoted(localProvider['provider_short_name']),
      quoted(str(localProvider['provider_index'])),
      quoted(remoteProvider['provider_short_name']),
      quoted(str(remoteProvider['provider_index'])),
      str(sndBufSize),
      str(rcvBufSize),
      quoted(str(cc)),
      quoted(str(cmt)),
      quoted(str(pathMgr)),
      str(nDiffPorts),
      str(ipVersion),
   ]
   variableSettings = ' '.join(fields)

   # The hash covers the parameters only, never the file names (which are
   # themselves derived from the hash).
   if hashResult:
      return hashlib.sha1(variableSettings.encode('utf-8')).hexdigest()

   for fileNameBase in (scalarFileNameBase, vectorFileNameBase):
      if fileNameBase is not None:
         variableSettings += ' ' + quoted(os.path.basename(fileNameBase))
   return variableSettings


# ###### Get scalar file name ###############################################
def getScalarName(now, localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                  ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts):
   """Return the scalar file name of a run.

   The name consists of the measurement directory (MeasurementName), the
   hash that getVariableSettings() computes for the run parameters, and the
   suffix '.sca.bz2'.
   """
   settingsHash = getVariableSettings(
      now, localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
      ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts,
      hashResult = True)
   return MeasurementName + '/' + settingsHash + '.sca.bz2'



# ###### Run the measurement ################################################
def runMeasurement(now, slice, summaryFile,
                   localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                   ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts):
   """Record one run in the summary file and start it on the local node.

   Two entries (active and passive scalar file) are written to summaryFile.
   Then a shell command line is assembled that creates the measurement
   directory, starts a background ping and runs netperfmeter. The command
   line is handed to runOverSSH() for localNode.

   Returns:
      The result of runOverSSH(); the caller treats it as a process object
      to wait for, or None.
   """

   # ------ Select the endpoints --------------------------
   remoteAddress = makeAddress(remoteSite, remoteNode, remoteProvider, ipVersion, slice)
   if cmt == 'off':   # ------ Single-path transport ------------------------
      remotePort       = makePort(PortBase, remoteSite, remoteNode, remoteProvider, ipVersion, slice)
      localAddress     = makeAddress(localSite, localNode, localProvider, ipVersion, slice)
      pingLocalAddress = localAddress
   else:              # ------ Multi-path transport -------------------------
      remotePort       = makePort(PortBase, remoteSite, remoteNode, None, ipVersion, slice)
      localAddress     = makeAddress(localSite, localNode, None, ipVersion, slice)
      # NOTE: Ping will record the initial path's RTTs!
      pingLocalAddress = makeAddress(localSite, localNode, localProvider, ipVersion, slice)

   # ------ Derive the file names -------------------------
   scalarFileName = getScalarName(now, localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                                  ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts)

   def derivedName(suffix):
      return scalarFileName.replace('.sca.bz2', suffix)

   vectorFileName        = derivedName('.vec.bz2')
   activeScalarFileName  = derivedName('-active.sca.bz2')
   passiveScalarFileName = derivedName('-passive.sca.bz2')
   pingFileName          = derivedName('.data')

   # ------ Write the summary entries ---------------------
   # The same parameter values are recorded for both scalar files.
   variables = getVariableSettings(now, localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                                   ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts,
                                   scalarFileName, vectorFileName)
   for inputFileName in (activeScalarFileName, passiveScalarFileName):
      summaryFile.write('--values="' + variables + '"\n')
      summaryFile.write('--input=' + inputFileName + '\n')

   # ------ Run NetPerfMeter ------------------------------
   # A receive buffer size of -1 is passed to NetPerfMeter as 0; the summary
   # entries written above keep the original value.
   if rcvBufSize == -1:
      rcvBufSize = 0

   ping = 'ping' if remoteAddress.version == 4 else 'ping6'

   pingCommand = \
      '( sudo ' + ping + ' -n -c' + str(Runtime) + \
      ' -I ' + str(pingLocalAddress) + " " + str(remoteAddress) + \
      '>' + pingFileName + ' & )'

   flowSettings = \
      'const0:const1400:const0:const0' + \
      ':sndbuf=' + str(sndBufSize) + \
      ':rcvbuf=' + str(rcvBufSize) + \
      ':cc=' + cc + \
      ':pathmgr=' + pathMgr + \
      ':ndiffports=' + str(nDiffPorts) + \
      ':cmt=' + cmt

   netPerfMeterCommand = \
      'netperfmeter [' + str(remoteAddress) + ']:' + str(remotePort) + ' ' + \
      '-local=[' + str(localAddress) + '] ' + \
      '-scalar=' + scalarFileName + ' ' + \
      '-vector=' + vectorFileName + ' \\\n\t' + \
      '-runtime=' + str(Runtime) + ' ' + \
      '-verbosity=1 ' + \
      '-tcp ' + flowSettings

   cmdLine = 'mkdir -p ' + MeasurementName + ' && \\\n' + \
             pingCommand + ' && \\\n' + \
             netPerfMeterCommand
   return runOverSSH(SSHPrivateKey, localNode, slice, cmdLine, True)



# ###########################################################################
# #### Main program                                                      ####
# ###########################################################################

# ====== Get configuration ==================================================
loginToPLC()
fullSiteList = fetchNorNetSiteList()
# NOTE(review): this result is not used anywhere in this script ("fillNodeList"
# is presumably a typo of "fullNodeList"). The call is kept in case the fetch
# has side effects the later lookups rely on -- TODO confirm.
fillNodeList = fetchNorNetNodeList()

# The local node is appended first, the remote node second.
# NOTE: error() may already terminate the script; the explicit sys.exit(1)
# ensures a non-zero exit status and that the script never continues
# with a missing node or slice.
ExperimentNodes = []
localNode = fetchNorNetNode(LocalNodeName)
if localNode is None:
   error('Node not found: ' + LocalNodeName)
   sys.exit(1)
ExperimentNodes.append(localNode)

remoteNode = fetchNorNetNode(RemoteNodeName)
if remoteNode is None:
   error('Node not found: ' + RemoteNodeName)
   sys.exit(1)
ExperimentNodes.append(remoteNode)

Slice = fetchNorNetSlice(SliceName)
if Slice is None:
   error('Slice not found: ' + SliceName)
   sys.exit(1)

# ====== Test or install NetPerfMeter from sources? =========================
# With a command-line argument, the script only performs the requested
# maintenance command on the experiment nodes and then exits:
#    test -- testCustomNetPerfMeter()
#    init -- installCustomNetPerfMeter()
if len(sys.argv) > 1:
   command = sys.argv[1]
   if command == 'test':
      testCustomNetPerfMeter(SSHPrivateKey, ExperimentNodes, Slice)
   elif command == 'init':
      installCustomNetPerfMeter(SSHPrivateKey, ExperimentNodes, Slice)
   else:
      error('Unknown command: ' + command)
      # An unknown command is a failure: do not report success to the caller.
      sys.exit(1)
   sys.exit(0)


# ====== Prepare ============================================================
print('\x1b[34;1m###### Stage 1: Preparations ######\x1b[0m')

# ------ Prepare results output ---------------------------
# An already existing results directory is fine.
os.makedirs(MeasurementName, exist_ok = True)
fullSummaryFileName = MeasurementName + '/results.summary'
summaryFileName     = MeasurementName + '/results.summary.new'
# Best effort: remove a left-over summary of a previous run. Only file system
# errors are ignored here, so that e.g. Ctrl-C is not swallowed.
try:
   os.remove(summaryFileName)
except OSError:
   pass
# The built-in open() is used, since "codecs" is not imported by this file.
# newline='\n' disables newline translation, i.e. the lines are written
# exactly as given.
summaryFile = open(summaryFileName, 'w+', encoding = 'utf-8', newline = '\n')
summaryFile.write('--varnames=' + getVariableNames() + '\n')
summaryFile.write('--resultsdirectory=' + MeasurementName + '\n')


# ====== Start passive side ==============================================
print('\x1b[34;1m###### Stage 2: Starting passive side ######\x1b[0m')

pathMgr = 'fullmesh'

# Start the server on all experiment nodes first, then wait for all of the
# returned processes. startServer() may return None; such results are skipped.
processes = []
for node in ExperimentNodes:
   newProcess = startServer(fullSiteList, PortBase, MeasurementName,
                            SSHPrivateKey, node, Slice, pathMgr)
   if newProcess is not None:
      processes.append(newProcess)
for process in processes:
   process.wait()
# Pause before the measurements begin (presumably to let the servers
# settle -- the reason is not stated in this script).
time.sleep(10)


# ====== Create runs ========================================================
print('\x1b[34;1m###### Stage 3: Running measurements ######\x1b[0m')
allRuns = []
localSite          = getNorNetSiteOfNode(fullSiteList, localNode)
localProviderList  = getNorNetProvidersForSite(localSite)
remoteSite         = getNorNetSiteOfNode(fullSiteList, remoteNode)
remoteProviderList = getNorNetProvidersForSite(remoteSite)

# NOTE: Repetitions should better *not* be configured here. Instead, run this
#       script multiple times in sequence!
for repetition in range(1, 2):
   # NOTE: "now" *must* differ for each repetition (to ensure distinct file names)!
   now = int(time.time()) + repetition

   # One block of runs for each combination of local and remote provider.
   for localProviderIndex in localProviderList:
      localProvider = localProviderList[localProviderIndex]
      for remoteProviderIndex in remoteProviderList:
         remoteProvider = remoteProviderList[remoteProviderIndex]

         # 4, 6, 46   (4=IPv4-only, 6=IPv6-only, 46=IPv4+IPv6)
         for ipVersion in [ 6 ]:
            fromLabel = localNode['node_utf8']  + '/' + localProvider['provider_short_name']
            toLabel   = remoteNode['node_utf8'] + '/' + remoteProvider['provider_short_name']
            print(fromLabel + ' -> ' + toLabel + ' over IPv' + str(ipVersion))

            # The loop value is multiplied by 1024 and used for both the
            # send and the receive buffer size.
            for bufsize in range(0, 20000, 1000):
               for cmt in ['mptcp']:   # 'off'
                  cc         = 'hybla'
                  sndBufSize = bufsize * 1024
                  rcvBufSize = sndBufSize
                  pathMgr    = 'fullmesh'
                  nDiffPorts = 1
                  # The element order is relied upon by the code that
                  # executes the runs.
                  newRun = [
                     now,
                     localSite, localNode, localProvider,
                     remoteSite, remoteNode, remoteProvider,
                     ipVersion, sndBufSize, rcvBufSize,
                     cc, cmt, pathMgr, nDiffPorts
                  ]
                  allRuns.append(newRun)

# Execute all runs in batches: a batch contains only runs whose local and
# remote nodes are not used by another run of the same batch.
runNumber = 0
totalRuns = len(allRuns)
while allRuns:
   # ------ Select the next batch -------------------------
   parallelRuns   = []
   allocatedNodes = set()
   for run in allRuns:
      # run[2] is the local node, run[5] the remote node.
      localNodeName  = run[2]['node_name']
      remoteNodeName = run[5]['node_name']
      if (localNodeName in allocatedNodes) or (remoteNodeName in allocatedNodes):
         continue
      allocatedNodes.add(localNodeName)
      allocatedNodes.add(remoteNodeName)
      parallelRuns.append(run)

   print('###### PARALLEL ######')

   # ------ Stage 1: announce the runs and dequeue them ---
   # The first run of allRuns is always selected, so each pass of the
   # outer loop removes at least one run.
   for run in parallelRuns:
      runNumber = runNumber + 1
      allRuns.remove(run)
      print('====== Run ' + str(runNumber) + ' of ' + str(totalRuns) + ' ======')

   # ------ Stage 2: start all runs of the batch ----------
   processes = []
   for run in parallelRuns:
      (now,
       localSite, localNode, localProvider,
       remoteSite, remoteNode, remoteProvider,
       ipVersion, sndBufSize, rcvBufSize,
       cc, cmt, pathMgr, nDiffPorts) = run
      newProcess = runMeasurement(now, Slice, summaryFile,
                                  localSite, localNode, localProvider, remoteSite, remoteNode, remoteProvider,
                                  ipVersion, sndBufSize, rcvBufSize, cc, cmt, pathMgr, nDiffPorts)
      if newProcess is not None:
         processes.append(newProcess)

   # ------ Stage 3: wait for the batch to finish ---------
   for process in processes:
      process.wait()
   processes = []


# ====== Collect results ====================================================
print('\x1b[34;1m###### Stage 4: Results Collection ######\x1b[0m')
summaryFile.close()

# Fetch the measurement directory from each experiment node.
for node in ExperimentNodes:
   copyFromNodeOverRSync(SSHPrivateKey, node, Slice, MeasurementName)

# Merge the summaries: the new summary comes first, followed by the entries
# of an already existing full summary (without its --varnames= and
# --resultsdirectory= lines). The result then replaces the full summary.
updatedSummaryFileName = fullSummaryFileName + '.updated'
mergeCommand = \
   'cat ' + summaryFileName + ' >' + updatedSummaryFileName + \
   ' && if [ -e ' + fullSummaryFileName + ' ] ; then' + \
   ' grep -v ^--varnames= ' + fullSummaryFileName + \
   ' | grep -v ^--resultsdirectory= >>' + updatedSummaryFileName + ' ; fi' + \
   ' && mv ' + updatedSummaryFileName + ' ' + fullSummaryFileName
runLocal(mergeCommand)

# Finally, run the "createsummary" program on the merged summary.
createSummaryCommand = 'createsummary xy -ignore-scalar-file-errors <' + MeasurementName + '/results.summary'
runLocal(createSummaryCommand)


# ====== Clean up ===========================================================
print('\x1b[34;1m###### Stage 5: Clean-up ######\x1b[0m')
# Stop the server on all experiment nodes first, then wait for all of the
# returned processes. stopServer() may return None; such results are skipped.
processes = []
for node in ExperimentNodes:
   newProcess = stopServer(MeasurementName, SSHPrivateKey, node, Slice)
   if newProcess is not None:
      processes.append(newProcess)
for process in processes:
   process.wait()
